//
// Created by Xing on 2022/3/26.
//

#ifndef CAFFE_REDUCE_L2_LAYER_HPP
#define CAFFE_REDUCE_L2_LAYER_HPP

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {
    // Caffe counterpart of the ONNX ReduceL2 operator (ONNX ReduceL2 <==> Caffe).
    /**
     * @brief Layer registered under the type name "ReduceL2".
     *
     * Consumes exactly one bottom blob and produces exactly one top blob.
     * LayerSetUp, Reshape and Forward_cpu are only declared here; their
     * definitions live outside this header. Neither backward pass is
     * implemented, so the layer is usable for forward (inference) only.
     */
    template <typename Dtype>
    class ReduceL2Layer : public Layer<Dtype> {
    public:
        /**
         * @brief Builds the layer from its protobuf description.
         *
         * The members are given deterministic placeholder values so that they
         * are never read uninitialized; presumably LayerSetUp/Reshape assign
         * the real values afterwards -- TODO confirm against the .cpp file.
         *
         * @param param Layer configuration forwarded to Layer<Dtype>.
         */
        explicit ReduceL2Layer(const LayerParameter& param)
                : Layer<Dtype>(param), axis_(0), num_(0), keep_dim_(false) {}

        /// One-time setup hook (defined out of line).
        virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
        /// Shape-adjustment hook (defined out of line).
        virtual void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

        /// @return The layer type string, "ReduceL2".
        virtual inline const char* type() const { return "ReduceL2"; }
        /// @return 1 -- the layer takes exactly one input blob.
        virtual inline int ExactNumBottomBlobs() const { return 1; }
        /// @return 1 -- the layer produces exactly one output blob.
        virtual inline int ExactNumTopBlobs() const { return 1; }

    protected:
        /// CPU forward pass (defined out of line).
        virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

        /**
         * @brief GPU forward pass.
         *
         * There is no dedicated GPU kernel, so this delegates to Forward_cpu,
         * which is the same fallback Layer<Dtype> uses by default. Previously
         * this override hit NOT_IMPLEMENTED, making the layer unusable in GPU
         * mode even though a CPU implementation exists.
         */
        virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
            this->Forward_cpu(bottom, top);
        }

        /// Backward pass on CPU is not implemented.
        virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
                                  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ NOT_IMPLEMENTED; }
        /// Backward pass on GPU is not implemented.
        virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
                                  const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){ NOT_IMPLEMENTED; }

        int axis_;       // presumably the axis being reduced -- TODO confirm in the .cpp
        int num_;        // meaning not visible from this header -- TODO confirm in the .cpp
        bool keep_dim_;  // presumably whether the reduced axis is kept in the output -- TODO confirm
    };

}  // namespace caffe

#endif //CAFFE_REDUCE_L2_LAYER_HPP
